{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['data_batch_1', 'data_batch_2', 'data_batch_5', 'data_batch_4', 'data_batch_3', 'batches.meta', 'readme.html', 'test_batch']\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import os\n",
    "CIFAR_DIR = \"/home/herb/code/data/dataset/cf10data\"\n",
    "print(os.listdir(CIFAR_DIR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(filename):\n",
    "    \"\"\"\n",
    "    read data from data file\n",
    "    \"\"\"\n",
    "    with open(filename,'rb') as f:\n",
    "        data = pickle.load(f,encoding='bytes')\n",
    "        return data[b'data'],data[b'labels']\n",
    "    \n",
    "class CifarData:\n",
    "    def __init__(self,filenames,need_shuffle):\n",
    "        all_data = []\n",
    "        all_labels = []\n",
    "        for filename in filenames:\n",
    "            data,labels = load_data(filename)\n",
    "            all_data.append(data)\n",
    "            all_labels.append(labels)\n",
    "                    \n",
    "        self._data = np.vstack(all_data)\n",
    "        # 数据归一化\n",
    "        self._data = self._data/127.5 -1\n",
    "        self._labels = np.hstack(all_labels)\n",
    "        self._num_examples = self._data.shape[0]\n",
    "        self._need_shuffle = need_shuffle\n",
    "        self._indicator = 0\n",
    "        if self._need_shuffle:\n",
    "            self._shuffle_data()\n",
    "            \n",
    "    def _shuffle_data(self):\n",
    "        p = np.random.permutation(self._num_examples)\n",
    "        self._data= self._data[p]\n",
    "        self._labels = self._labels[p]\n",
    "        \n",
    "    def next_batch(self,batch_size):\n",
    "        end_indicator = self._indicator + batch_size\n",
    "        if end_indicator > self._num_examples:\n",
    "            if self._need_shuffle:\n",
    "                self._shuffle_data()\n",
    "                self._indicator = 0\n",
    "                end_indicator = batch_size\n",
    "            else:\n",
    "                raise Exception(\"error\")\n",
    "        if end_indicator > self._num_examples:\n",
    "            raise Exception(\"batch size is larger than all examples\")\n",
    "        \n",
    "        batch_data = self._data[self._indicator:end_indicator]\n",
    "        batch_lables = self._labels[self._indicator:end_indicator]\n",
    "        self._indicator = end_indicator\n",
    "        return batch_data, batch_lables\n",
    "    \n",
    "\n",
    "train_filenames = [os.path.join(CIFAR_DIR,'data_batch_%d'%i)for i in range(1,6)]\n",
    "test_filenames = [os.path.join(CIFAR_DIR,'test_batch')]\n",
    "\n",
    "train_data = CifarData(train_filenames,True)\n",
    "test_data = CifarData(test_filenames,False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/herb/program/anaconda3/envs/cifar10/lib/python3.6/site-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "# 占位符\n",
    "x = tf.placeholder(tf.float32,[None,3072])\n",
    "y = tf.placeholder(tf.int64,[None])\n",
    "\n",
    "# (3072,10)\n",
    "w = tf.get_variable('w', [x.get_shape()[-1],10],initializer=tf.random_normal_initializer(0,1))\n",
    "\n",
    "#(10,)\n",
    "b = tf.get_variable('b',[10],initializer=tf.constant_initializer(0.0))\n",
    "# [None,3072] * [3072,10] = [None,10]\n",
    "y_ = tf.matmul(x, w)+b\n",
    "\"\"\"\n",
    "平方差损失函数\n",
    "p_y = tf.nn.softmax(y_)\n",
    "\n",
    "y_one_hot = tf.one_hot(y,10,dtype=tf.float32)\n",
    "\n",
    "loss = tf.reduce_mean(tf.square(y_one_hot - p_y))\n",
    "#[None,1]\n",
    "\"\"\"\n",
    "\n",
    "loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)\n",
    "\n",
    "# indices\n",
    "predict = tf.argmax(y_, 1)\n",
    "\n",
    "correct_prediction = tf.equal(predict,y)\n",
    "\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float64))\n",
    "\n",
    "\n",
    "with tf.name_scope('train_op'):\n",
    "    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train step: 0, loss: 42.938697814941406 acc:0.05\n",
      "---------------0.12650000000000003\n",
      "Train step: 500, loss: 10.159871101379395 acc:0.15\n",
      "Train step: 1000, loss: 9.623239517211914 acc:0.3\n",
      "Train step: 1500, loss: 6.651302337646484 acc:0.3\n",
      "Train step: 2000, loss: 8.594182968139648 acc:0.3\n",
      "Train step: 2500, loss: 9.62916374206543 acc:0.3\n",
      "Train step: 3000, loss: 7.005529880523682 acc:0.3\n",
      "Train step: 3500, loss: 9.977266311645508 acc:0.25\n",
      "Train step: 4000, loss: 4.447190284729004 acc:0.55\n",
      "Train step: 4500, loss: 8.87260627746582 acc:0.15\n",
      "Train step: 5000, loss: 6.547215461730957 acc:0.15\n",
      "---------------0.2595\n",
      "Train step: 5500, loss: 9.412397384643555 acc:0.15\n",
      "Train step: 6000, loss: 6.996195316314697 acc:0.2\n",
      "Train step: 6500, loss: 5.99907112121582 acc:0.3\n",
      "Train step: 7000, loss: 7.944432735443115 acc:0.1\n",
      "Train step: 7500, loss: 6.187597751617432 acc:0.45\n",
      "Train step: 8000, loss: 9.315279006958008 acc:0.1\n",
      "Train step: 8500, loss: 3.9547340869903564 acc:0.35\n",
      "Train step: 9000, loss: 5.934263229370117 acc:0.2\n",
      "Train step: 9500, loss: 3.906879425048828 acc:0.25\n",
      "Train step: 10000, loss: 2.7826037406921387 acc:0.35\n",
      "---------------0.28\n",
      "Train step: 10500, loss: 3.252213954925537 acc:0.4\n",
      "Train step: 11000, loss: 5.924036979675293 acc:0.25\n",
      "Train step: 11500, loss: 4.927115440368652 acc:0.4\n",
      "Train step: 12000, loss: 6.605323791503906 acc:0.2\n",
      "Train step: 12500, loss: 3.549969434738159 acc:0.3\n",
      "Train step: 13000, loss: 6.69631290435791 acc:0.2\n",
      "Train step: 13500, loss: 5.180264472961426 acc:0.35\n",
      "Train step: 14000, loss: 4.816896915435791 acc:0.25\n",
      "Train step: 14500, loss: 4.903558254241943 acc:0.25\n",
      "Train step: 15000, loss: 5.71185302734375 acc:0.35\n",
      "---------------0.2795\n",
      "Train step: 15500, loss: 3.76466703414917 acc:0.4\n",
      "Train step: 16000, loss: 4.789539813995361 acc:0.2\n",
      "Train step: 16500, loss: 3.5592784881591797 acc:0.35\n",
      "Train step: 17000, loss: 5.787440299987793 acc:0.2\n",
      "Train step: 17500, loss: 4.853079795837402 acc:0.25\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-27b4e6179f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0mbatch_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_labels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnext_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m         \u001b[0mloss_val\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0macc_val\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_op\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbatch_labels\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m\u001b[0;36m500\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/program/anaconda3/envs/cifar10/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    948\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    949\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 950\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    951\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    952\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/program/anaconda3/envs/cifar10/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1156\u001b[0m     \u001b[0;31m# Create a fetch handler to take care of the structure of fetches.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1157\u001b[0m     fetch_handler = _FetchHandler(\n\u001b[0;32m-> 1158\u001b[0;31m         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)\n\u001b[0m\u001b[1;32m   1159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1160\u001b[0m     \u001b[0;31m# Run request and get response.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/program/anaconda3/envs/cifar10/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, graph, fetches, feeds, feed_handles)\u001b[0m\n\u001b[1;32m    493\u001b[0m            fetch.op.type == 'GetSessionHandleV2')):\n\u001b[1;32m    494\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_handles\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 495\u001b[0;31m     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_final_fetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetches\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    496\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    497\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_assert_fetchable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/program/anaconda3/envs/cifar10/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    493\u001b[0m            fetch.op.type == 'GetSessionHandleV2')):\n\u001b[1;32m    494\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetch_handles\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 495\u001b[0;31m     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_final_fetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetches\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    496\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    497\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_assert_fetchable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/program/anaconda3/envs/cifar10/lib/python3.6/site-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36m__hash__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    639\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m__hash__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    640\u001b[0m     \u001b[0;31m# Necessary to support Python's collection membership operators\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 641\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    642\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    643\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m__eq__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "init = tf.global_variables_initializer()\n",
    "batch_size = 20\n",
    "train_steps = 100000\n",
    "test_steps = 100\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for i in range(train_steps):\n",
    "        batch_data,batch_labels = train_data.next_batch(batch_size)\n",
    "        loss_val,acc_val,_ = sess.run([loss,accuracy,train_op],feed_dict={x:batch_data,y:batch_labels})\n",
    "        \n",
    "        if i %500 == 0:\n",
    "            print(\"Train step: {}, loss: {} acc:{}\".format(i,loss_val,acc_val))\n",
    "            \n",
    "        if i %5000 == 0:\n",
    "            test_data = CifarData(test_filenames, False)\n",
    "            all_test_acc_val = []\n",
    "            for j in range(test_steps):\n",
    "                test_batch_data,test_batch_label = test_data.next_batch(batch_size)\n",
    "                test_acc_val = sess.run([accuracy], feed_dict = {x:test_batch_data,y:test_batch_label})\n",
    "                all_test_acc_val.append(test_acc_val)\n",
    "            \n",
    "            test_acc = np.mean(all_test_acc_val)\n",
    "            print(\"---------------{}\".format(test_acc))\n",
    "                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
